#!/usr/bin/env python3
"""
Evaluate token-level log-probs on a chunked-XXX dataset using vLLM.

Auto-detects each model’s context window (max_seq_len / max_position_embeddings).
Truncates prompts so that *after* the tokenizer adds its own BOS/EOS the
  final length ≤ max_model_len (handles Yi, Llama, Mistral, Pythia, …).
Supports data-parallel (DP) and tensor-parallel (TP) execution in vLLM.
"""

# ── std-lib ──────────────────────────────────────────────────────────────────
import os, sys, logging
from time import sleep
from multiprocessing import Process, Queue
from typing import List, Optional

# ── third-party ──────────────────────────────────────────────────────────────
import pandas as pd
import pyarrow as pa, pyarrow.parquet as pq
from transformers import AutoTokenizer, AutoConfig
from vllm import LLM, SamplingParams
from vllm.sequence import Logprob
import yaml

# ════════════════════════════════════════════════════════════════════════════
# Helper utilities
# ════════════════════════════════════════════════════════════════════════════
_SENTINEL_THRESHOLD = 1_000_000  # values >= this are treated as “infinite/unknown”

def _safe_len(val) -> Optional[int]:
    """Return *val* as an int if it is a plausible context length, else None.

    A plausible length is a positive integer strictly below
    ``_SENTINEL_THRESHOLD``.  Everything else maps to ``None``: ``None``
    itself, non-numeric values, NaN / ±inf, zero, negative numbers and
    huge "no limit" sentinels.
    """
    if val is None:
        return None
    try:
        length = int(val)
    except (TypeError, ValueError, OverflowError):
        # TypeError: not number-like; ValueError: NaN or a non-numeric
        # string; OverflowError: ±inf.
        return None
    # Zero and negative values are never a usable window; a negative value
    # would otherwise be truthy and win an ``a or b or default`` chain.
    return length if 0 < length < _SENTINEL_THRESHOLD else None

def get_tokenizer_and_max_len(model_name: str):
    """Return ``(tokenizer, max_model_len)`` with sensible fall-backs.

    The context window is resolved in this order:

    1. ``max_position_embeddings`` or ``max_seq_len`` on the model config,
    2. ``model_max_length`` on the tokenizer,
    3. a default of 2048 (a warning is logged in that case).

    Every candidate goes through ``_safe_len`` so "unlimited" sentinels and
    other implausible values are ignored.
    """
    # NOTE: trust_remote_code=True lets the model repository run its own
    # Python code while loading; only use with models you trust.
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    cfg_len = None
    try:
        cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    except Exception as exc:
        # Best-effort: the tokenizer limit or the default can still be used,
        # but say so instead of hiding the failure.
        logging.warning(
            f"[{model_name}] could not load model config ({exc}); "
            "falling back to tokenizer limits"
        )
    else:
        # Different architectures name their context window differently.
        for attr in ("max_position_embeddings", "max_seq_len"):
            candidate = _safe_len(getattr(cfg, attr, None))
            if candidate:
                cfg_len = candidate
                break

    tok_len = _safe_len(getattr(tok, "model_max_length", None))

    max_len = cfg_len or tok_len or 2048
    if cfg_len is None and tok_len is None:
        logging.warning(f"[{model_name}] context window unknown – using 2048")
    return tok, max_len

# ─── robust, model-agnostic truncation ───────────────────────────────────────
def safe_truncate(prompt: str, tok, max_len: int) -> str:
    """
    Trim *prompt* so that
        len(tok.encode(result, add_special_tokens=True)) ≤ max_len
    holds for the string that is actually returned.

    Args:
        prompt:  text to truncate.
        tok:     tokenizer exposing Hugging Face style ``encode`` / ``decode``.
        max_len: maximum token count, including any special tokens the
                 tokenizer adds on its own.

    Returns:
        The prompt decoded from at most ``max_len - overhead - 1`` tokens,
        right-stripped.  The text always goes through an encode/decode
        round-trip, even when nothing had to be cut.

    Raises:
        ValueError: if *max_len* leaves no room for at least one prompt token
            after the tokenizer's own overhead and a 1-token cushion.
    """
    # Number of special tokens the tokenizer adds to an empty string.
    overhead = len(tok.encode("", add_special_tokens=True))
    budget = max_len - overhead - 1           # leave an extra 1-token cushion
    if budget <= 0:
        raise ValueError("max_len smaller than tokenizer overhead!")

    ids = tok.encode(prompt, add_special_tokens=False)[:budget]

    # Decoding and re-encoding is not guaranteed to give the same number of
    # tokens, so verify the exact string that will be returned (same decode
    # options, same rstrip) and drop trailing tokens until it fits.
    while True:
        text = tok.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        ).rstrip()
        # `not ids`: nothing left to drop, so stop instead of spinning forever.
        if not ids or len(tok.encode(text, add_special_tokens=True)) <= max_len:
            return text
        ids = ids[:-1]        # drop one token and retry

def load_vllm_config(path: str) -> dict:
    """Load engine options from the YAML file at *path*.

    Returns:
        The parsed top-level mapping, or ``{}`` when the file is empty (an
        empty YAML document parses to ``None``).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the top-level YAML value is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if cfg is None:
        # Empty file: callers expect a dict they can update and splat.
        return {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping of options, "
            f"got {type(cfg).__name__}"
        )
    return cfg

def process_logprobs(
    raw: "Optional[List[Optional[dict[int, Logprob]]]]",
) -> List[float]:
    """Flatten per-position prompt log-prob dicts into a list of floats.

    Args:
        raw: one entry per prompt position; each entry is ``None`` or a dict
            whose values carry a ``.logprob`` attribute.  ``None`` or an
            empty list is accepted and yields ``[]``.

    Returns:
        The ``.logprob`` of the first value of every non-empty entry after
        position 0, in order.  Position 0 is always skipped (presumably
        because the first prompt token has nothing to condition on — confirm
        against the vLLM version in use).
    """
    if not raw:
        return []
    return [
        next(iter(entry.values())).logprob
        for entry in raw[1:]
        if entry  # skips None and empty dicts (the latter would leak StopIteration)
    ]

def configure_logging(verbose: bool = False):
    """Set up root logging; in quiet mode also silence chatty libraries.

    With ``verbose`` the root logger is configured at INFO.  Otherwise it is
    configured at WARNING and a fixed list of library loggers is raised to
    ERROR first.
    """
    root_level = logging.WARNING
    if verbose:
        root_level = logging.INFO
    else:
        noisy_loggers = (
            "vllm",
            "vllm.engine",
            "vllm.worker",
            "vllm.distributed",
            "transformers",
            "transformers.tokenization_utils_base",
            "torch",
        )
        for logger_name in noisy_loggers:
            logging.getLogger(logger_name).setLevel(logging.ERROR)

    logging.basicConfig(level=root_level, format="%(levelname)s %(message)s")

def get_slice(df: pd.DataFrame, dp: int, rank: int) -> pd.DataFrame:
    """Return the contiguous block of rows owned by *rank* out of *dp* shards.

    Rows are split as evenly as possible: the first ``len(df) % dp`` ranks
    get one extra row each.
    """
    base, extra = divmod(len(df), dp)
    # Every lower rank contributed `base` rows, plus one for each of the
    # first `extra` ranks.
    first = rank * base + min(rank, extra)
    size = base + (1 if rank < extra else 0)
    return df.iloc[first:first + size]

# ════════════════════════════════════════════════════════════════════════════
# Core building blocks
# ════════════════════════════════════════════════════════════════════════════
def build_llm(model: str, tp: int, cfg: dict) -> LLM:
    """Create the vLLM engine for *model*.

    *tp* is passed as ``tensor_parallel_size``; every key in *cfg* is
    forwarded as an extra keyword argument to ``LLM``.
    """
    engine = LLM(tensor_parallel_size=tp, model=model, **cfg)
    return engine

def make_prompts(rows, col, tok, max_len):
    """Return the *col* text of every row, truncated to fit *max_len* tokens."""
    prompts = []
    for row in rows:
        prompts.append(safe_truncate(row[col], tok, max_len))
    return prompts

# ── single-process evaluation ────────────────────────────────────────────────
def eval_slice(model, tp, cfg, df, col, verbose=False):
    """Score every row of *df* in this process and return the results.

    Args:
        model:   model name or path.
        tp:      tensor-parallel size for the engine.
        cfg:     dict of extra engine options; ``max_model_len`` in it, when
                 set, is also used as the truncation limit.
        df:      frame with ``text_id``, ``chunk_id`` and the *col* column.
        col:     name of the column holding the prompt text.
        verbose: forwarded to ``configure_logging``.

    Returns:
        One dict per row with keys ``text_id``, ``chunk_id`` and
        ``logprobs`` (a list of floats).
    """
    configure_logging(verbose)
    tok, max_len = get_tokenizer_and_max_len(model)
    # The engine is built from cfg, so a configured max_model_len is the
    # limit that counts; fall back to the detected window only when absent.
    max_len = cfg.get("max_model_len") or max_len
    rows    = df.to_dict("records")
    prompts = make_prompts(rows, col, tok, max_len)
    logging.info(f"[single] {len(prompts)} prompts")

    llm  = build_llm(model, tp, cfg)
    # Only the prompt log-probs are of interest; one generated token is
    # requested and its text is discarded.
    samp = SamplingParams(temperature=0.0, prompt_logprobs=0, max_tokens=1)

    # Outputs are paired with rows positionally.
    out = llm.generate(prompts, samp)
    return [
        {
            "text_id": int(r["text_id"]),
            "chunk_id": int(r["chunk_id"]),
            "logprobs": process_logprobs(o.prompt_logprobs),
        }
        for r, o in zip(rows, out)
    ]

def run_single(model, tp, cfg, df, col, outfile, verbose=False):
    """Evaluate *df* in the current process and write one Parquet file.

    The output has three columns: ``text_id`` (int32), ``chunk_id`` (int32)
    and ``logprobs`` (list of float32).  The writer is closed even when the
    evaluation raises.
    """
    logging.info("Single-process evaluation …")

    column_types = (
        ("text_id", pa.int32()),
        ("chunk_id", pa.int32()),
        ("logprobs", pa.list_(pa.float32())),
    )
    schema = pa.schema([pa.field(name, typ) for name, typ in column_types])

    writer = pq.ParquetWriter(outfile, schema)
    try:
        results = eval_slice(model, tp, cfg, df, col, verbose)
        # Pivot the list of row dicts into one list per column.
        columns = {
            name: [row[name] for row in results]
            for name, _ in column_types
        }
        writer.write_table(pa.Table.from_pydict(columns, schema=schema))
        logging.info(f"Wrote {len(results)} rows → {outfile}")
    finally:
        writer.close()

# ── DP worker ────────────────────────────────────────────────────────────────
def worker(model, dp, tp, rank, m_ip, m_port, cfg, df, col, q, verbose=False):
    """Data-parallel worker: score this rank's share of *df* and stream results.

    Each result is put on *q* as a dict with keys ``text_id``, ``chunk_id``
    and ``logprobs``.  A final ``None`` marks the end of this worker's
    output and is sent even when the worker fails, so the consumer never
    waits for a sentinel that will not come.  A failure still propagates and
    gives the process a non-zero exit code.

    Args:
        model:   model name or path.
        dp:      total number of data-parallel workers.
        tp:      tensor-parallel size for the engine.
        rank:    index of this worker in ``range(dp)``.
        m_ip:    master address exported as ``VLLM_DP_MASTER_IP``.
        m_port:  master port exported as ``VLLM_DP_MASTER_PORT``.
        cfg:     dict of extra engine options; ``max_model_len`` in it, when
                 set, is also used as the truncation limit.
        df:      the full frame; this worker takes ``get_slice(df, dp, rank)``.
        col:     name of the column holding the prompt text.
        q:       queue shared with the parent process.
        verbose: forwarded to ``configure_logging``.
    """
    configure_logging(verbose)
    # Exported before the engine is created; presumably read by vLLM's
    # data-parallel set-up — confirm against the vLLM version in use.
    os.environ.update(
        {
            "VLLM_DP_RANK": str(rank),
            "VLLM_DP_RANK_LOCAL": str(rank),
            "VLLM_DP_SIZE": str(dp),
            "VLLM_DP_MASTER_IP": m_ip,
            "VLLM_DP_MASTER_PORT": str(m_port),
        }
    )

    try:
        tok, max_len = get_tokenizer_and_max_len(model)
        # The engine is built from cfg, so a configured max_model_len is the
        # limit that counts; fall back to the detected window when absent.
        max_len = cfg.get("max_model_len") or max_len
        slice_df = get_slice(df, dp, rank)
        rows     = slice_df.to_dict("records")
        prompts  = make_prompts(rows, col, tok, max_len)
        logging.info(f"[DP {rank}] {len(prompts)} prompts")

        llm  = build_llm(model, tp, cfg)
        samp = SamplingParams(temperature=0.0, prompt_logprobs=0, max_tokens=1)

        # Outputs are paired with rows positionally.
        for r, o in zip(rows, llm.generate(prompts, samp)):
            q.put(
                {
                    "text_id": int(r["text_id"]),
                    "chunk_id": int(r["chunk_id"]),
                    "logprobs": process_logprobs(o.prompt_logprobs),
                }
            )
    finally:
        # Always signal completion, success or failure.
        q.put(None)
    # NOTE(review): short grace period before the process exits; presumably
    # lets the engine / queue settle — confirm whether it is still needed.
    sleep(1)

def run_multi(model, dp, tp, cfg, df, col, outfile, verbose=False):
    """Run *dp* worker processes and collect their results into one Parquet file.

    Workers stream result dicts over a queue and finish with a ``None``
    sentinel.  Rows are written in batches of 1000, in arrival order, so the
    row order across workers is not deterministic.

    The queue is polled with a timeout.  If every worker has exited and the
    queue is drained before all sentinels arrived (e.g. a worker was
    killed), collection stops instead of blocking forever and a non-zero
    code is returned.

    Returns:
        0 when all workers finished cleanly; otherwise the exit code of a
        failed worker, or 1 if a worker vanished without reporting.
    """
    from queue import Empty  # raised by Queue.get() when the timeout expires

    poll_seconds = 5.0   # how often to check that the workers are still alive
    flush_rows = 1000    # rows buffered in memory before each write

    logging.info(f"Launching {dp} DP workers")
    m_ip, m_port = "127.0.0.1", 8000
    q = Queue()

    schema = pa.schema([
        pa.field("text_id", pa.int32()),
        pa.field("chunk_id", pa.int32()),
        pa.field("logprobs", pa.list_(pa.float32())),
    ])
    writer = pq.ParquetWriter(outfile, schema)

    def flush(rows):
        """Write the buffered rows as one table and empty the buffer."""
        tbl = pa.Table.from_pydict(
            {
                "text_id": [b["text_id"] for b in rows],
                "chunk_id": [b["chunk_id"] for b in rows],
                "logprobs": [b["logprobs"] for b in rows],
            }, schema=schema
        )
        writer.write_table(tbl)
        rows.clear()

    procs = []
    finished, batch = 0, []
    try:
        # Started inside the try so the writer is closed if start-up fails.
        for r in range(dp):
            p = Process(
                target=worker,
                args=(model, dp, tp, r, m_ip, m_port, cfg, df, col, q, verbose),
            )
            p.start()
            procs.append(p)

        while finished < dp:
            try:
                item = q.get(timeout=poll_seconds)
            except Empty:
                if any(p.is_alive() for p in procs):
                    continue  # still working, keep waiting
                # All workers are gone.  Take one more non-blocking look in
                # case something arrived between the timeout and this check.
                try:
                    item = q.get_nowait()
                except Empty:
                    logging.error(
                        f"Only {finished}/{dp} workers reported completion "
                        "before all workers exited"
                    )
                    break

            if item is None:
                finished += 1
            else:
                batch.append(item)

            if len(batch) >= flush_rows:
                flush(batch)

        if batch:
            flush(batch)
    finally:
        writer.close()

    err = 0
    for p in procs:
        p.join()
        if p.exitcode:
            err = p.exitcode
            logging.error(f"Worker {p.pid} exited {p.exitcode}")
    if finished < dp and err == 0:
        # A sentinel is missing although no exit code says so: still a failure.
        err = 1
    if err == 0:
        logging.info("All DP workers finished OK")
    return err

# ════════════════════════════════════════════════════════════════════════════
# CLI
# ════════════════════════════════════════════════════════════════════════════
def parse_args():
    """Parse the command line (``sys.argv``) and return the argparse namespace."""
    import argparse

    parser = argparse.ArgumentParser("XXX perplexity eval (vLLM)")
    # (flag, add_argument keyword options), in the original order.
    options = (
        ("--model", {"default": "XXX/XXX"}),
        ("--dp-size", {"type": int, "default": 2}),
        ("--tp-size", {"type": int, "default": 1}),
        ("--dataset", {"default": "data/XXX/chunked_texts_df.parquet"}),
        ("--prompt-column", {"default": "chunk_text"}),
        ("--config", {"default": "vllm_config.yaml"}),
        ("--output-file", {"default": "output/XXX.parquet"}),
        ("--debug-limit", {"type": int, "default": 0}),
        ("--verbose", {"action": "store_true"}),
    )
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()

# ════════════════════════════════════════════════════════════════════════════
# Main
# ════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    args = parse_args()
    configure_logging(args.verbose)

    # Zero workers would silently write an empty file and exit 0.
    if args.dp_size < 1 or args.tp_size < 1:
        sys.exit("--dp-size and --tp-size must be >= 1")

    # An empty YAML file parses to None; treat it as "no extra options".
    vllm_cfg = load_vllm_config(args.config) or {}
    _, max_len = get_tokenizer_and_max_len(args.model)
    # Use the detected context window only if the config does not set one.
    vllm_cfg.setdefault("max_model_len", max_len)

    df = pd.read_parquet(args.dataset)
    if args.debug_limit:
        df = df.head(args.debug_limit)
        logging.info(f"[DEBUG] limited to {len(df)} rows")

    # Make sure the output directory exists before the writer opens the file.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if args.dp_size == 1:
        run_single(
            args.model, args.tp_size, vllm_cfg,
            df, args.prompt_column, args.output_file,
            args.verbose,
        )
    else:
        code = run_multi(
            args.model, args.dp_size, args.tp_size, vllm_cfg,
            df, args.prompt_column, args.output_file,
            args.verbose,
        )
        # Propagate worker failures to the shell.
        sys.exit(code)
